package scatter.generator.core;

import cn.hutool.core.util.ReUtil;
import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.generator.AutoGenerator;
import com.baomidou.mybatisplus.generator.config.DataSourceConfig;
import com.baomidou.mybatisplus.generator.config.StrategyConfig;
import com.baomidou.mybatisplus.generator.config.builder.ConfigBuilder;
import com.baomidou.mybatisplus.generator.config.po.TableField;
import com.baomidou.mybatisplus.generator.config.po.TableInfo;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.context.annotation.Scope;
import org.springframework.stereotype.Component;
import scatter.common.rest.tools.SpringContextHolder;
import scatter.generator.ext.ITableInfoExt;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/**
 * Extended auto generator.
 * <p>
 * Converts the table metadata returned by {@link AutoGenerator} into {@link CustomTableInfo} /
 * {@link CustomTableField} instances and enriches them with values derived from the table and
 * column comments and from the column metadata query.
 * <p>
 * NOTE(review): {@code @Scope} only takes effect when this class is registered as a Spring bean;
 * no stereotype annotation is declared on the class itself, so registration presumably happens
 * elsewhere — confirm.
 * <p>
 * Created by yangwei
 * Created at 2021/2/24 9:44
 */
@Scope(ConfigurableBeanFactory.SCOPE_PROTOTYPE)
public class CustomAutoGenerator extends AutoGenerator {

    /** Mapper used to copy the stock table/field descriptors into their custom counterparts. */
    ITableFieldMapStruct iTableFieldMapStruct = ITableFieldMapStruct.INSTANCE;

    /** Optional extension hook; looked up lazily from the Spring context when it was not set explicitly. */
    private ITableInfoExt iTableInfoExt;

    /**
     * Sets the extension hook invoked for every table and column.
     *
     * @param iTableInfoExt the extension to use; may be {@code null}, in which case a lookup through
     *                      {@link SpringContextHolder} is attempted when tables are processed
     */
    public void setiTableInfoExt(ITableInfoExt iTableInfoExt) {
        this.iTableInfoExt = iTableInfoExt;
    }

    /**
     * Only used to obtain the table information (no files are generated by this method itself).
     *
     * @param dataSource data source configuration; it is also stored on this generator
     * @param strategy   strategy configuration handed to the {@link ConfigBuilder}
     * @return the table information list as returned by {@link #getAllTableInfoList(ConfigBuilder)}
     */
    public List<TableInfo> customGetAllTableInfoList(DataSourceConfig dataSource, StrategyConfig strategy) {
        super.setDataSource(dataSource);
        config = new ConfigBuilder(getPackageInfo(), dataSource, strategy, getTemplate(), getGlobalConfig());
        pretreatmentConfigBuilder(config);
        return getAllTableInfoList(config);
    }

    /**
     * Returns the table list of the parent generator with every entry converted to a
     * {@link CustomTableInfo}.
     * <p>
     * When the parent list already contains a {@link CustomTableInfo} it has been processed before
     * and is returned unchanged.
     *
     * @param config the configuration builder passed through to the parent implementation
     * @return the already processed list, or a new list holding the converted tables
     */
    @Override
    protected List<TableInfo> getAllTableInfoList(ConfigBuilder config) {
        List<TableInfo> tables = super.getAllTableInfoList(config);
        // Already processed: do not process again.
        for (TableInfo table : tables) {
            if (table instanceof CustomTableInfo) {
                return tables;
            }
        }
        List<TableInfo> converted = new ArrayList<>(tables.size());
        for (TableInfo table : tables) {
            CustomTableInfo customTable = iTableFieldMapStruct.map(table);
            applyTableComment(customTable, table.getComment());
            fieldsHandle(customTable);
            converted.add(customTable);
        }
        return converted;
    }

    /**
     * Derives the simplified comment and, when present, the two relation names from a table comment.
     * <p>
     * The suffixes "关系表" and then "表" are stripped. If the remaining text splits into exactly two
     * parts on "和", those parts are stored as the relation comments.
     *
     * @param tableInfo  the table descriptor to update
     * @param rawComment the original table comment; may be {@code null}
     */
    private void applyTableComment(CustomTableInfo tableInfo, String rawComment) {
        String simpleComment = StrUtil.removeSuffix(rawComment, "关系表");
        simpleComment = StrUtil.removeSuffix(simpleComment, "表");
        tableInfo.setCommentSimple(simpleComment);
        if (simpleComment == null) {
            // A table without a comment has no relation parts; splitting null would throw.
            return;
        }
        // Relation ("rel") support.
        String[] relationParts = simpleComment.split("和");
        if (relationParts.length == 2) {
            tableInfo.setCommentRel1(relationParts[0]);
            tableInfo.setCommentRel2(relationParts[1]);
        }
    }

    /**
     * Replaces the fields of the given table with {@link CustomTableField} instances built from the
     * column metadata query of the configured data source.
     * <p>
     * The table-level extension hook runs first, before the fields are rebuilt. Columns returned by
     * the query that have no matching field on the table are skipped. If the query fails, the error
     * is printed to {@code System.err} and the table keeps only the fields collected up to that point
     * (best effort).
     *
     * @param tableInfo the table descriptor whose fields are rebuilt
     */
    private void fieldsHandle(CustomTableInfo tableInfo) {
        ITableInfoExt ext = resolveTableInfoExt();
        if (ext != null) {
            ext.tableInfoExt(tableInfo);
        }

        DataSourceConfig dataSource = getDataSource();
        String tableFieldsSql = String.format(dataSource.getDbQuery().tableFieldsSql(), tableInfo.getName());
        List<TableField> fields = new ArrayList<>();
        try (
                Connection conn = dataSource.getConn();
                PreparedStatement preparedStatement = conn.prepareStatement(tableFieldsSql);
                ResultSet results = preparedStatement.executeQuery()) {
            // The label of the column-name column does not change between rows; look it up once.
            String fieldNameLabel = dataSource.getDbQuery().fieldName();
            while (results.next()) {
                TableField sourceField = findField(tableInfo, results.getString(fieldNameLabel));
                if (sourceField == null) {
                    continue;
                }
                CustomTableField customField = toCustomField(sourceField, results);
                if (ext != null) {
                    ext.tableColumnExt(customField);
                }
                fields.add(customField);
            }
        } catch (SQLException e) {
            System.err.println("SQL Exception：" + e.getMessage() + " (table: " + tableInfo.getName() + ")");
        }
        tableInfo.setFields(fields);
    }

    /**
     * Returns the extension hook, trying a Spring context lookup when none has been set yet.
     *
     * @return the extension hook, or {@code null} when none is available
     */
    private ITableInfoExt resolveTableInfoExt() {
        if (iTableInfoExt == null) {
            try {
                iTableInfoExt = SpringContextHolder.getBean(ITableInfoExt.class);
            } catch (Exception ignored) {
                // The extension is optional: a failed lookup simply means no hook is applied.
            }
        }
        return iTableInfoExt;
    }

    /**
     * Finds the field of the table whose column name equals the given name.
     *
     * @param tableInfo  the table whose current fields are searched
     * @param columnName the column name read from the metadata query; may be {@code null}
     * @return the matching field, or {@code null} when there is none
     */
    private TableField findField(CustomTableInfo tableInfo, String columnName) {
        if (columnName == null) {
            return null;
        }
        for (TableField candidate : tableInfo.getFields()) {
            if (columnName.equals(candidate.getColumnName())) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Builds a {@link CustomTableField} from a field and the current row of the metadata query.
     * <p>
     * NOTE(review): the labels "Null", "Key" and "Type" look like the columns of MySQL's
     * {@code SHOW FULL FIELDS} output — confirm before using another database.
     *
     * @param sourceField the field to copy
     * @param results     result set positioned on the row describing the same column
     * @return the enriched custom field
     * @throws SQLException if a column of the current row cannot be read
     */
    private CustomTableField toCustomField(TableField sourceField, ResultSet results) throws SQLException {
        CustomTableField customField = iTableFieldMapStruct.map(sourceField);
        customField.setRequired("NO".equals(results.getString("Null")));
        customField.setUnique("UNI".equals(results.getString("Key")));

        // Flags are driven by keywords inside the column comment.
        String comment = sourceField.getComment();
        customField.setCommentSimple(simpleComment(comment));
        customField.setQueryLike(commentContains(comment, "模糊查询"));
        customField.setOrderBy(commentContains(comment, "排序"));
        customField.setOrderByAsc(commentContains(comment, "升序"));
        customField.setForeignKey(commentContains(comment, "外键"));

        setLength(customField, results);
        return customField;
    }

    /**
     * Returns the part of a column comment before the first ASCII comma and before the first
     * full-width comma.
     *
     * @param comment the column comment; may be {@code null}
     * @return the leading part of the comment, or {@code null} when the comment is {@code null}
     */
    private static String simpleComment(String comment) {
        if (comment == null) {
            return null;
        }
        return beforeFirst(beforeFirst(comment, ","), "，");
    }

    /**
     * Returns the text before the first occurrence of the separator, or the whole text when the
     * separator does not occur. Unlike {@code split(separator)[0]} this never fails for a text that
     * consists only of separators.
     *
     * @param text      the text to cut; must not be {@code null}
     * @param separator the literal separator
     * @return the leading part of the text
     */
    private static String beforeFirst(String text, String separator) {
        int index = text.indexOf(separator);
        return index < 0 ? text : text.substring(0, index);
    }

    /**
     * Tells whether a column comment contains the given keyword.
     *
     * @param comment the column comment; may be {@code null}
     * @param keyword the keyword to look for
     * @return {@code true} when the comment is not {@code null} and contains the keyword
     */
    private static boolean commentContains(String comment, String keyword) {
        return comment != null && comment.contains(keyword);
    }

    /**
     * Sets the field length, and the fraction length when present, from the "Type" column of the
     * current row.
     * <p>
     * The text between the parentheses of the type is split on commas: the first part is used as the
     * length and the second as the fraction length. Parts that are missing or not integers are
     * skipped; a type without a parenthesised part changes nothing.
     *
     * @param customeTableField the field to update
     * @param results           result set positioned on the row describing the field
     * @throws SQLException if the "Type" column cannot be read
     */
    private void setLength(CustomTableField customeTableField, ResultSet results) throws SQLException {
        String type = results.getString("Type");
        String length = ReUtil.get("(?<=\\()(\\S+)(?=\\))", type, 0);
        if (length == null) {
            return;
        }
        String[] parts = length.split(",");
        if (parts.length > 0) {
            Integer totalLength = parseIntOrNull(parts[0]);
            if (totalLength != null) {
                customeTableField.setLength(totalLength);
            }
        }
        if (parts.length > 1) {
            Integer fractionLength = parseIntOrNull(parts[1]);
            if (fractionLength != null) {
                customeTableField.setFractionLength(fractionLength);
            }
        }
    }

    /**
     * Parses a decimal integer.
     *
     * @param text the text to parse
     * @return the parsed value, or {@code null} when the text is not a valid integer
     */
    private static Integer parseIntOrNull(String text) {
        try {
            return Integer.parseInt(text);
        } catch (NumberFormatException ignored) {
            // Non-numeric type arguments (for example enum value lists) carry no length.
            return null;
        }
    }
}
